
# Script setup: start from an empty workspace, attach the packages used for
# the parallel replication loop, and load the project helper scripts.
rm(list = ls())
library(doParallel)
library(doRNG)
library(foreach)

library(abind)

options(error = stop)
# Helper scripts (evaluated in the global environment, in this order).
invisible(lapply(c("2.0.0 MyFunc.R",
                   "2.0.1 GetProb.R",
                   "2.0.3 GetContrast.R",
                   "2.1 DataGen.R",
                   "2.1.2 MLEst3.R",
                   "2.2.3 PMLEst4.R"),
                 source))

SEED.index <- 1 # DEBUG

### Start of true values specification ----
# Simulation design
n.cand      <- c(100, 200, 500, 5000)  # candidate sample sizes
numsims     <- 500                     # replications per sample size
#### Parameters that have stable properties
# alpha.true  = t(matrix(rep(c(0,-0.5),5),2,5))
alpha.true  <- t(matrix(rep(c(0, 0.7), 5), 2, 5)) #CHANGE
# alpha.true  = t(matrix(rep(c(-0.5,1),5),2,5)) #CHANGE
# beta.true   = t(matrix(rep(c(-0.5,0.1),5),2,5))
# beta.true is 5 x 2; row k holds beta.k.
beta.3 <- c(-0.5, 1)
beta.1 <- beta.2 <- beta.4 <- beta.5 <- c(-0.5, 0.1)
beta.true   <- t(matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5), 2, 5))

### Parameters to illustrate g-null paradox
# alpha.true  = t(matrix(rep(c(0,0),5),2,5))
# beta.3 =  c(-0.5, 0.1)
# beta.1 = c(-0.5, -1); beta.2 = c(-0.5, 1.5)
# beta.4 = c(-0.5, 1.5); beta.5 = c(-0.5, -2)
# beta.true   = t(matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5),2,5))

# datagen = DataGen(5000, SEED=seeds[seed.index])
# MyAttach(datagen)
# cnt

# Propensity-score coefficients: row "a0" has two coefficients (padded with
# NA), row "a1" has four.
gamma.true  <- t(matrix(c(0.1, -0.5, NA, NA, 0.1, -0.5, 0.1, -0.5), 4, 2))
# Name the rows BEFORE building params.true so the stored copy carries them.
rownames(gamma.true) <- c("a0", "a1")
est.method.cand <- c("MLE3", "PMLE4", "DR")
params.true <- list(alpha.true = alpha.true, beta.true = beta.true,
                    gamma.true = gamma.true)

### To store results ----
# results[simulation, sample size, method, row (1-5), column (1/2),
#         parameter (alpha/beta), measure (est/nll/time/contrast)]
sim.name <- paste0("sim=", seq_len(numsims))
n.name   <- paste0("n=", n.cand)
est.method.name <- as.character(est.method.cand)
results  <- array(dim = c(numsims, length(n.cand), length(est.method.cand),
                          nrow(alpha.true), ncol(alpha.true), 2, 4),
                  dimnames = list(sim.name, n.name, est.method.name,
                                  seq_len(nrow(alpha.true)),
                                  seq_len(ncol(alpha.true)),
                                  c("alpha", "beta"),
                                  c("est", "nll", "time", "contrast")))
names(dimnames(results)) <- c("Simulation number", "Sample size", "est.method",
                              "1-5", "1/2", "Param", "Measure")

# Compress messages. This is a global variable read elsewhere.
# (Previously written `MESSAGE <-- FALSE`, which parses as
# `MESSAGE <- -FALSE` and stored the integer 0 instead of FALSE.)
MESSAGE <- FALSE

ptm0 <- proc.time()


{
    # Sample size for this run (index into n.cand).
    n.index <- 4
    
    ptm1 <- proc.time()
    n <- n.cand[n.index]
    
    cat("Sample size: ",n,"\n")
    # One seed per replication, reproducible from the sample size.
    set.seed(n)
    seeds <- runif(10000, -2^30, 2^30)
    # Replication i uses seeds[i + (SEED.index - 1) * numsims].
    stopifnot(SEED.index * numsims <= length(seeds))
    {
        ee.assume.true.alpha <- TRUE
        
        # Nuisance model for y.hat.type for DR estimation
        y.hat.type <- "logit" # "saturated" or "MLE" or "logit" or "diy"
        prop.score.type <- "correct" # "correct" or "incorrect"
        vectorize.get.prob <- FALSE; parallel.root <- FALSE
        l0.type <- "discrete" # "continuous" or "discrete"
        l1.type <- "correct" # "correct" or "incorrect" or "diy"
        gop.type <- "correct" # "correct" or "incorrect"
    }
    
    {
        # Which estimators to run (1 = run, 0 = skip).
        doMLE <- 0; doPMLE <- 1
        doDR <- 1
        doGComp <- 1
        doDebugEE <- 0
    }
    {
        source("2.5 CallDR.R")
        source("2.5.1 DR.R")
        source("2.5.1 DR_Iter.R")
        source("2.5.3 DR_MyFunc.R")
        source("2.6 DR_EE_value.R")
    }
    
    # G-computation estimates: one column per (l0, a0, a1) cell, 000 ... 111.
    GComp.res <- array(dim = c(numsims, 8))
    {
        # Estimating-equation diagnostics (filled only when doDebugEE).
        ee.res <- array(dim = c(numsims, 5 * 2))
        ee.stage1 <- ee.d1 <- ee.d0 <- array(dim = c(numsims, 5 * 2))
        ee.u1 <- ee.u0 <- numeric(numsims)
    }
    # Replications are processed in M batches of per.iter each.
    M <- 5
    stopifnot(numsims %% M == 0)
    per.iter <- numsims / M
    
    ### 25 Jul 2021: run 10 times DR estimation.
    # Use n.index (not a literal) so this stays tied to the configured run.
    dr.res.10 <- array(dim = c(10, dim(results[, n.index, "PMLE4", , , , ])))
    dimnames(dr.res.10) <- list(Initial = 1:10, sim.name,
                                1:nrow(alpha.true), 1:ncol(alpha.true),
                                c("alpha", "beta"),
                                c("est", "nll", "time", "contrast"))
    names(dimnames(dr.res.10))[2:6] <- c("Simulation Number", "1-5", "1/2",
                                         "Param", "Measure")
}
# Main simulation loop ----
# The numsims replications are split into M batches of per.iter each. Each
# batch is evaluated with foreach. Its per-replication outputs are then copied
# into `results` / `GComp.res`, and a checkpoint is saved before the next batch.
# NOTE(review): no registerDoParallel()/registerDoSEQ() call is visible in this
# script. Unless a sourced file registers a backend, `%dopar%` runs the body
# sequentially (foreach warns about this). Assignments made in the body may
# then persist from one replication to the next. For that reason the
# sections below do not overwrite each other's working variables.
t0.total <- proc.time()
for (iter in seq_len(M)) {
    boot.list <- ((iter - 1) * per.iter + 1):(iter * per.iter)
    boot1 <- foreach(i = boot.list, .errorhandling = "stop") %dopar% {
        seed.perturb <- 0
        seed.index <- i + (SEED.index - 1) * numsims + seed.perturb
        set.seed(seeds[seed.index])
        cat("\t Simulation number: ",seed.index-seed.perturb,"\n")    
        
        ### 0: Data generation: x, y
        source("2.1 DataGen.R")
        
        datagen <- DataGen(n, SEED = seeds[seed.index])
        MyAttach(datagen)
        
        ### 1 = Maximum likelihood estimate
        mle3 <- NULL
        if (doMLE) {
            set.seed(seeds[seed.index])
            t0 <- proc.time()
            mle3 <- MLEst3(data, cnt)
            t1 <- proc.time()
            print(paste0("MLE time: ", (t1 - t0)[3], " secs"))
            results[i, n.index, "MLE3", , , , ] <- mle3
        }
        
        pmle4 <- NULL
        source("2.0.1 GetProb.R")
        if (doPMLE) {
            set.seed(seeds[seed.index])
            t2 <- proc.time()
            pmle4 <- PMLEst4(data, cnt)
            results[i, n.index, "PMLE4", , , , ] <- pmle4
            t3 <- proc.time()
            print(paste0("PMLE time: ", round(as.numeric(t3 - t2)[3],2), " secs")) 
        }
        # pmle4 is still NULL when doPMLE == 0. Indexing NULL with [, , 1, 1]
        # would stop the whole replication.
        if (!is.null(pmle4)) {
            alpha.pmle <- pmle4[, , 1, 1]
            beta.pmle <- pmle4[, , 2, 1]
        }
        
        ### 2 = DR Estimate based on Vansteelandt and Joffe (2014)
        #   2.1 Estimate gamma propensity score model
        if (prop.score.type == "correct") {
            gamma.a0 <- glm.fit(a0.cov, a0.vec, family = binomial(),
                                intercept = FALSE)$coefficients
            gamma.a1 <- glm.fit(a1.cov, a1.vec, family = binomial(),
                                intercept = FALSE)$coefficients
        } else {
            # Misspecified propensity model: zero out all but the intercept.
            a0.cov.incor <- a0.cov; a1.cov.incor <- a1.cov
            a0.cov.incor[, 2] <- 0; a1.cov.incor[, 2:ncol(a1.cov.incor)] <- 0
            gamma.a0 <- glm.fit(a0.cov.incor, a0.vec, family = binomial(),
                                intercept = FALSE)$coefficients
            gamma.a1 <- glm.fit(a1.cov.incor, a1.vec, family = binomial(),
                                intercept = FALSE)$coefficients
            gamma.a0[2] <- 0
            gamma.a1[2:ncol(a1.cov.incor)] <- 0
        }
        gamma <- t(matrix(c(gamma.a0, NA, NA, gamma.a1), 4, 2))
        rownames(gamma) <- c("a0", "a1")
        #   2.2 Estimate outcome regression for y
        source("2.5 CallDR.R")
        source("2.5.1 DR.R")
        source("2.5.1 DR_Iter.R")
        source("2.5.3 DR_MyFunc.R")
        source("2.0.0 MyFunc.R")
        source("2.0.1 GetProb.R")
        
        dr.sol <- NULL
        if (doDR) {
            if (!doPMLE) {
                set.seed(seeds[seed.index])
                pmle4 <- PMLEst4(data, cnt)
                results[i, n.index, "PMLE4", , , , ] <- pmle4
            }
            
            alpha.pmle <- pmle4[, , 1, 1]
            beta.pmle <- pmle4[, , 2, 1]
            
            if (gop.type == "incorrect") {
                tmp <- rm.l0.tilde(data, cnt)
                data <- tmp$data
                cnt <- tmp$cnt
                source("2.5.3 DR_MyFunc.R")
                nb <- 32
                vec.1 <- rep(1, nb)
                vec.0 <- rep(0, nb)
            }
            
            if (y.hat.type == "saturated") {
                # Use saturated model
                y.hat <- cnt[17:32] / (cnt[17:32] + cnt[1:16])
            } else if (y.hat.type == "MLE") {
                # Use the invert map
                DataGen.obj <- func.DataGen(c(alpha.pmle, beta.pmle, gamma.true),
                                            data, cnt)
                y.hat <- DataGen.obj$p.y[1:16]
            } else if (y.hat.type == "logit") {
                # Use logistic regression
                y.hat <- glm.fit(y.cov, y.vec, family = binomial(),
                                 intercept = FALSE)$coefficients
            } else if (y.hat.type == "diy") {
                y.hat <- c(0.5, 0, 0, 0, 0)
            }
            
            if (l1.type == "incorrect") { # then build a logit reg of l1~l0 + a0
                l1.cov <- y.cov[, c("l0.1", "l0.2", "a0")]
                l1.hat <- glm.fit(l1.cov, l1.vec, family = binomial(),
                                  intercept = FALSE)$coefficients
            }
            
            ####
            # Construct the warm start from the logit outcome model
            ####
            # First stage RR using g-formula
            RR.logit <- array(0, c(5, 2))
            for (l0 in c(0, 1)) {
                l0.vec <- c(1, l0)
                numer <- y.hat.func(l0 = l0.vec, a0 = 1, l1 = 1, a1 = 0, y.hat) *
                    p.l1.func(a0 = 1, l0.vec, beta.pmle, nb, l1.hat) +
                    y.hat.func(l0 = l0.vec, a0 = 1, l1 = 0, a1 = 0, y.hat) *
                    (1 - p.l1.func(a0 = 1, l0.vec, beta.pmle, nb, l1.hat))
                denom <- y.hat.func(l0 = l0.vec, a0 = 0, l1 = 1, a1 = 0, y.hat) *
                    p.l1.func(a0 = 0, l0.vec, beta.pmle, nb, l1.hat) +
                    y.hat.func(l0 = l0.vec, a0 = 0, l1 = 0, a1 = 0, y.hat) *
                    (1 - p.l1.func(a0 = 0, l0.vec, beta.pmle, nb, l1.hat))
                RR.logit[1, l0 + 1] <- numer / denom
            }
            # Second stage blips: rows 2..5 are (a0, l1) = (1,1), (1,0), (0,1), (0,0)
            for (a0 in c(0, 1)) {
                for (l1 in c(0, 1)) {
                    for (l0 in c(0, 1)) {
                        l0.vec <- c(1, l0)
                        numer <- y.hat.func(l0 = l0.vec, a0 = a0, l1 = l1, a1 = 1, y.hat)
                        denom <- y.hat.func(l0 = l0.vec, a0 = a0, l1 = l1, a1 = 0, y.hat)
                        RR.logit[2 + 2 * (1 - a0) + (1 - l1), l0 + 1] <- numer / denom
                    }
                }
            }
            log.RR.logit <- log(RR.logit)
            log.RR.logit[, 2] <- log.RR.logit[, 2] - log.RR.logit[, 1] # separate the intercept
            alpha.logit <- log.RR.logit
            # Store alpha.logit in the alpha slot of the returned PMLE result.
            # Writing into `results` directly here does not reach the
            # collected output, which is built from the returned list.
            pmle4[, , 1, 1] <- alpha.logit
            
            #   2.3 Estimate alpha (beta and p0 are given by MLE)
            t4 <- proc.time()
            set.seed(seeds[seed.index])
            results[i, n.index, "DR", , , , ] <- dr.sol <-
                DREst(data, cnt, alpha.logit, beta.pmle, gamma, y.hat, l1.hat,
                      SEED = seeds[seed.index], dr.warm = "DIY",
                      dr.warm.obj = alpha.logit)
            t5 <- proc.time()
            # Report elapsed seconds only, as for MLE/PMLE. Pasting the whole
            # proc_time object would print five separate strings.
            print(paste0("DR time: ", round(as.numeric(t5 - t4)[3], 2), " secs"))
            
            alpha.dr <- dr.sol[, , 1, 1]
            beta.dr <- dr.sol[, , 2, 1]
            
            source("2.6 DR_EE_value.R")
            
            # Multi-start experiment (disabled), stores into dr.res.10:
            # for(j in 1:10){
            #     dr.warm = dr.warm.obj = NULL
            #     if (j == 1){
            #         dr.warm="DIY"
            #         dr.warm.obj = alpha.logit
            #     }else{
            #         dr.warm = "random"
            #     }
            #     dr.res.10[j,i,,,,]= dr.est =DREst(data,cnt,alpha.logit,beta.pmle,gamma, y.hat,l1.hat, 
            #               SEED = seeds[seed.index]+j, dr.warm=dr.warm, dr.warm.obj=dr.warm.obj)
            #     (alpha.dr = dr.est[, , 1, 1])
            #     (beta.dr = dr.est[, , 2, 1])
            #     ee.val = EE.value(data, cnt, alpha.dr,beta.dr, gamma,
            #                        Hessian=FALSE,y.hat)
            #     dr.res.10[j,i,,,,"nll"] = nll.j = sum(ee.val$tmp^2)
            #     if (nll.j < 0.005){ # Here point.dr.est.MLE$nll is the DR objective value
            #         # Copy and paste this result and stop
            #         if (seed.start < 10){
            #             for (k in (j+1):10){
            #                 dr.res.10[k,i,,,,]= dr.res.10[j,i,,,,]
            #             }
            #         }
            #         break
            #     }
            # }
        }
        
        ############
        ### G-computation of E[Y(a0, a1) | L0 = l0]
        # Uses its own fits and helpers (suffix .gcomp). It previously
        # reassigned y.hat, l1.hat, l1.type, y.hat.func and p.l1.func. Those
        # changes leaked into EE.value() below and into later replications
        # evaluated in the same environment.
        GComp.res.i <- GComp.res[i, ]
        if (doGComp) {
            # 21Jul2021 change: outcome model uses y.cov.tilde
            y.hat.gcomp <- glm.fit(y.cov.tilde, y.vec, family = binomial(),
                                   intercept = FALSE)$coefficients
            # Logistic regression of l1 on (l0, a0)
            l1.hat.gcomp <- glm.fit(y.cov[, c("l0.1", "l0.2", "a0")], l1.vec,
                                    family = binomial(),
                                    intercept = FALSE)$coefficients
            # P(Y = 1 | l0, a0, l1, a1); covariate order (l1, a0, l0, a1).
            y.hat.func.gcomp <- function(l0, a0, l1, a1, y.hat) {
                if (is.matrix(l0)) {
                    c(expit(cbind(l1, a0, l0, a1) %*% y.hat))
                } else {
                    expit(c(l1, a0, l0, a1) %*% y.hat)
                }
            }
            # P(L1 = 1 | l0, a0); covariate order (l0, a0).
            p.l1.func.gcomp <- function(a0, l0, l1.hat) {
                if (is.matrix(l0)) {
                    c(expit(cbind(l0, a0) %*% l1.hat))
                } else {
                    expit(c(l0, a0) %*% l1.hat)
                }
            }
            
            # Cell index 4*l0 + 2*a0 + a1 + 1 (columns 000 ... 111).
            for (l0 in c(0, 1)) {
                l0.vec <- c(1, l0)
                for (a0 in c(0, 1)) {
                    p.l1 <- p.l1.func.gcomp(a0, l0.vec, l1.hat.gcomp)
                    for (a1 in c(0, 1)) {
                        GComp.res.i[4 * l0 + 2 * a0 + a1 + 1] <-
                            y.hat.func.gcomp(l0.vec, a0, 1, a1, y.hat.gcomp) * p.l1 +
                            y.hat.func.gcomp(l0.vec, a0, 0, a1, y.hat.gcomp) * (1 - p.l1)
                    }
                }
            }
        } # End of doGComp
        if (doDebugEE) {
            # Requires doDR: alpha.dr / beta.dr come from the DR step above.
            # NOTE(review): the ee.* assignments are not part of the returned
            # list, so they are not collected from parallel workers.
            source("2.6 DR_EE_value.R")
            ee.val <- EE.value(data, cnt, alpha.dr, beta.dr, gamma,
                               Hessian = FALSE, y.hat)
            ee.u1[i] <- ee.val$tmp.u1
            ee.u0[i] <- ee.val$tmp.u0
            for (r.idx in 1:5) {
                for (c.idx in 1:2) {
                    ee.res[i, 2 * (r.idx - 1) + c.idx] <- ee.val$tmp[r.idx, c.idx]
                    ee.stage1[i, 2 * (r.idx - 1) + c.idx] <- ee.val$tmp.stage1[r.idx, c.idx]
                    ee.d1[i, 2 * (r.idx - 1) + c.idx] <- ee.val$tmp.d1[r.idx, c.idx]
                    ee.d0[i, 2 * (r.idx - 1) + c.idx] <- ee.val$tmp.d0[r.idx, c.idx]
                }
            }
        }
        list(mle3 = mle3, pmle4 = pmle4, dr.sol = dr.sol, GComp.res.i = GComp.res.i)
    } # End of foreach over boot.list
    
    # Copy this batch's outputs into the result containers.
    for (k in seq_along(boot.list)) {
        i <- boot.list[k]
        tmp <- boot1[[k]]
        if (doMLE) {
            results[i, n.index, "MLE3", , , , ] <- tmp$mle3
        }
        # pmle4 is also produced inside the DR step when doPMLE == 0.
        if (!is.null(tmp$pmle4)) {
            results[i, n.index, "PMLE4", , , , ] <- tmp$pmle4
        }
        if (doDR) {
            results[i, n.index, "DR", , , , ] <- tmp$dr.sol
        }
        if (doGComp) {
            GComp.res[i, ] <- tmp$GComp.res.i
        }
    }
    save(results, GComp.res, params.true,
         file = paste0("./RDataFiles/11Jun2021-parallel-n-1000-cts-l0-theta-not-0-l1-logit-iter-", iter, ".RData"))
} # End iter
t1.total <- proc.time()
t1.total - t0.total
#############
# Summarise simulation results: Monte Carlo means and marginal g-null ANOVA
##############



# mean.res = results[1,n.index,"PMLE4",,,,]
# n.index = 4




# Monte Carlo averages over the first curr.sim replications.
# The mean.res* objects hold sums; they are divided by curr.sim when printed.
# colSums() over a drop = FALSE slice replaces the old `for (i in 2:curr.sim)`
# loop, which double-counted when curr.sim == 1.
curr.sim <- numsims
stopifnot(curr.sim >= 1)
sim.idx <- seq_len(curr.sim)
mean.res <- colSums(GComp.res[sim.idx, , drop = FALSE])
# Slices have dims (curr.sim, 1, 1, 5, 2, 2, 4); summing the first 3 dims
# yields the same (5, 2, 2, 4) array a single results[i, n.index, m, , , , ]
# slice has.
mean.res.pmle <- colSums(results[sim.idx, n.index, "PMLE4", , , , , drop = FALSE],
                         dims = 3)
mean.res.DR <- colSums(results[sim.idx, n.index, "DR", , , , , drop = FALSE],
                       dims = 3)

(mean.res.pmle / curr.sim)[, , 1, 1]
(mean.res.DR / curr.sim)[, , 1, 1]
mean.res / curr.sim

# One-way ANOVA of the L0-marginalised g-computation estimates across the four
# (a0, a1) regimes: a test of the marginal g-null hypothesis.
wide <- GComp.res[seq_len(curr.sim), , drop = FALSE]
colnames(wide) <- c("000", "001", "010", "011",
                    "100", "101", "110", "111")
wide <- data.frame(wide)
library(data.table)
long <- melt(setDT(wide), measure.vars = 1:8, variable.name = "l0a0a1")
# cond.g.null.aov = aov(value~l0a0a1, data=long[1:(4*curr.sim),])
# summary(cond.g.null.aov)
# melt() stacks the 8 columns in order (l0 a0 a1 = 000, 001, ..., 111).
# The first 4*curr.sim values therefore have l0 = 0 and the rest l0 = 1.
# Pull the numeric column explicitly. `long[rows, 'value']` returns a one-column
# data.table and only worked through implicit data.frame coercion.
value.l0.0 <- long[["value"]][seq_len(4 * curr.sim)]
value.l0.1 <- long[["value"]][4 * curr.sim + seq_len(4 * curr.sim)]
marginal.data <- data.frame(value = 0.5 * (value.l0.0 + value.l0.1),
                            a0a1 = rep(c("00", "01", "10", "11"), each = curr.sim))
marginal.g.null.aov <- aov(value ~ a0a1, data = marginal.data) # because L0 is Bern(1/2)
summary(marginal.g.null.aov)
print(paste("This sample size takes", round((proc.time() - ptm1)[3], 2),
            "s"))

# } # End for n.index in 1:length(n.cand))

# Approximate the true contrasts from one very large generated data set
# (fixed SEED = 0) and attach them to params.true.
rownames(gamma.true) <- c("a0", "a1")
datagen <- DataGen(1e5, SEED = 0)
MyAttach(datagen)
params.true[["contrast.true"]] <- GetContrast(data, alpha.true, beta.true)
rownames(gamma.true) <- c("a0", "a1")

# Total wall-clock time since the script started.
print(paste("It takes", round((proc.time() - ptm0)[["elapsed"]], 2), "s"))



# Persist the final results for this SEED.index.
save(results, GComp.res, params.true, dr.res.10,
     file = paste0("./RDataFiles/8Jul2021-n-1000-theta-not-0-g-comp-y-logit",
                   SEED.index, ".RData"))
# save(results, GComp.res, params.true, file=paste("./RDataFiles/15Jul2021-n-1000-theta-0-g-comp-y-misgop",SEED.index,".RData",sep=""))

# NOTE(review): the lines below change the working directory to a local
# machine path. They then load an older run's file, which overwrites the
# in-memory `results`, `GComp.res` and `params.true` computed above. They look
# like interactive scratch work and will fail on machines without this path.
cluster.path <- "~/Desktop/causal res/SNMM_Local/3. Data and Programming/cluster/080820"
setwd(cluster.path)
load("./RDataFiles/15Jul2021-n-1000-theta-0-g-comp-y-misgop1.RData")

